{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "47e81551-616e-4b1c-aa97-4cb39512a333",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/tomazbratanic/anaconda3/lib/python3.11/site-packages/pandas/core/arrays/masked.py:60: UserWarning: Pandas requires version '1.3.6' or newer of 'bottleneck' (version '1.3.5' currently installed).\n",
      "  from pandas.core import (\n"
     ]
    }
   ],
   "source": [
    "from neo4j import GraphDatabase\n",
    "import aiohttp\n",
    "import asyncio\n",
    "from langchain_text_splitters import TokenTextSplitter\n",
    "from datetime import datetime\n",
    "# Token-based splitter: 512-token chunks with a 56-token overlap.\n",
    "text_splitter = TokenTextSplitter(\n",
    "    chunk_size=512,\n",
    "    chunk_overlap=56,\n",
    ")\n",
    "\n",
    "# Shared Neo4j driver used by the query cells below; the password is an empty string in this notebook.\n",
    "driver = GraphDatabase.driver(\n",
    "    \"neo4j+s://diffbot.neo4jlabs.com:7687\",\n",
    "    auth=(\"neo4j\", \"\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "32318b6b-e6a6-4ef6-904a-3e9557dd2920",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_org_ids(top_k: int = 1000):\n",
    "    \"\"\"Return the ids of the `top_k` most important organizations.\n",
    "\n",
    "    Organizations whose `importance` is null are excluded; the rest are\n",
    "    ordered by `importance` descending before the limit is applied.\n",
    "    \"\"\"\n",
    "    cypher = (\n",
    "        \"MATCH (o:Organization) WHERE o.importance IS NOT NULL RETURN o.id AS id ORDER BY o.importance DESC LIMIT toInteger($limit)\"\n",
    "    )\n",
    "    result = driver.execute_query(cypher, limit=top_k)\n",
    "    org_ids = []\n",
    "    for record in result.records:\n",
    "        org_ids.append(record['id'])\n",
    "    return org_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "92460238-9247-49f9-8621-177364fdd7a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "DIFFBOT_TOKEN = \"\"\n",
    "\n",
    "async def get_articles(entity_id: str, date_after: str):\n",
    "    \"\"\"Query the Diffbot DQL endpoint for English articles tagged with an entity.\n",
    "\n",
    "    Args:\n",
    "        entity_id: Diffbot entity id, interpolated into the `tags.uri` filter.\n",
    "        date_after: Lower bound interpolated into the `date>` filter.\n",
    "\n",
    "    Returns:\n",
    "        The decoded JSON body when the response status is 200, otherwise None.\n",
    "        The request asks for `size=100` results sorted by date.\n",
    "    \"\"\"\n",
    "    base_url = \"https://kg.diffbot.com/kg/v3/dql\"\n",
    "\n",
    "    # DQL filter expression; it contains spaces, quotes and '>' and therefore\n",
    "    # must not be spliced into the URL unencoded.\n",
    "    query = f'type:Article tags.uri:\"http://diffbot.com/entity/{entity_id}\" strict:language:\"en\" date>\"{date_after}\" sortBy:date'\n",
    "\n",
    "    # Let aiohttp build the query string so the token, ids and DQL text are\n",
    "    # percent-encoded; a raw '&', '#', '+' or '%' would otherwise corrupt it.\n",
    "    params = {\"type\": \"query\", \"token\": DIFFBOT_TOKEN, \"query\": query, \"size\": \"100\"}\n",
    "\n",
    "    async with aiohttp.ClientSession() as session:\n",
    "        async with session.get(base_url, params=params) as response:\n",
    "            if response.status != 200:\n",
    "                # Non-200 responses yield None; callers must handle it.\n",
    "                return None\n",
    "            return await response.json()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3d45ddc9-3e67-4421-ad32-f1a4b091217f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cypher import statement; the batch of article dicts is passed as the `$data` parameter.\n",
    "# For each row it MERGEs an Article keyed on `diffbotUri` and sets its scalar properties,\n",
    "# then runs one CALL subquery per relationship type:\n",
    "#   - Category via HAS_CATEGORY (score stored on the relationship)\n",
    "#   - Tag via HAS_TAG (tags without a uri are skipped; label set only on create)\n",
    "#   - Chunk via HAS_CHUNK (merged per article on the chunk index, text overwritten)\n",
    "#   - Author via HAS_AUTHOR (only when row.author is not null)\n",
    "#   - Video via HAS_VIDEO (videos without a url are skipped)\n",
    "# Most subqueries return an aggregate count, so an empty list still yields one row per article.\n",
    "import_query = \"\"\"\n",
    "UNWIND $data AS row\n",
    "MERGE (a:Article {id: row.diffbotUri})\n",
    "SET a.pageUrl = row.pageUrl,\n",
    "    a.title = row.title,\n",
    "    a.language = row.language,\n",
    "    a.sentiment = toFloat(row.sentiment),\n",
    "    a.date = row.date_obj,\n",
    "    a.summary = row.summary,\n",
    "    a.publisherCountry = row.publisherCountry,\n",
    "    a.publisherRegion = row.publisherRegion\n",
    "WITH a, row\n",
    "CALL (a, row) {\n",
    "UNWIND row.categories AS category\n",
    "MERGE (c:Category {name: category.name})\n",
    "MERGE (a)-[hc:HAS_CATEGORY]->(c)\n",
    "SET hc.score = toFloat(category.score)\n",
    "RETURN count(*) AS count\n",
    "}\n",
    "WITH a, row\n",
    "CALL (a, row) {\n",
    "UNWIND [el in row.tags WHERE el.uri IS NOT NULL | el ] AS tag\n",
    "MERGE (t:Tag {id: tag.uri})\n",
    "ON CREATE SET t.label = tag.label\n",
    "MERGE (a)-[ht:HAS_TAG]->(t)\n",
    "SET ht.score = toFloat(tag.score),\n",
    "    ht.sentiment = toFloat(tag.sentiment)\n",
    "RETURN count(*) AS count\n",
    "}\n",
    "WITH a, row\n",
    "CALL (a, row) {\n",
    "UNWIND row.texts as text\n",
    "MERGE (a)-[:HAS_CHUNK]->(c:Chunk {index: text.index})\n",
    "SET c.text = text.text\n",
    "RETURN count(*) AS count\n",
    "}\n",
    "WITH a, row\n",
    "CALL (a, row) {\n",
    "WITH a, row\n",
    "WHERE row.author IS NOT NULL\n",
    "MERGE (au:Author {name: row.author})\n",
    "MERGE (a)-[:HAS_AUTHOR]->(au)\n",
    "}\n",
    "WITH a, row\n",
    "CALL (a, row) {\n",
    "UNWIND [el in row.videos WHERE el.url IS NOT NULL] as video\n",
    "MERGE (v:Video {uri: video.url})\n",
    "SET v.summary = video.summary,\n",
    "    v.name = video.name\n",
    "MERGE (a)-[:HAS_VIDEO]->(v)\n",
    "RETURN count(*) AS count\n",
    "}\n",
    "RETURN count(*)\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "de4bb443-2cd6-4497-ae99-6ef6dc6302c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uniqueness constraints on the properties the import query MERGEs on;\n",
    "# IF NOT EXISTS lets this cell be re-run.\n",
    "constraint_statements = (\n",
    "    \"CREATE CONSTRAINT article_unique_id IF NOT EXISTS FOR (a:Article) REQUIRE a.id IS UNIQUE;\",\n",
    "    \"CREATE CONSTRAINT category_unique_name IF NOT EXISTS FOR (c:Category) REQUIRE c.name IS UNIQUE;\",\n",
    "    \"CREATE CONSTRAINT tag_unique_id IF NOT EXISTS FOR (t:Tag) REQUIRE t.id IS UNIQUE;\",\n",
    "    \"CREATE CONSTRAINT author_unique_name IF NOT EXISTS FOR (au:Author) REQUIRE au.name IS UNIQUE;\",\n",
    "    \"CREATE CONSTRAINT video_unique_uri IF NOT EXISTS FOR (v:Video) REQUIRE v.uri IS UNIQUE;\",\n",
    ")\n",
    "for statement in constraint_statements:\n",
    "    driver.execute_query(statement, database_='articles')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c60015b0-415f-41d1-ade7-7d0a9e9037b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "step = 1000\n",
    "\n",
    "# Define a helper function to process the articles for a single organization\n",
    "async def process_organization(organization, date_after):\n",
    "    \"\"\"Fetch one organization's articles and prepare them for import.\n",
    "\n",
    "    Each returned entity dict gains two keys: `texts`, a list of\n",
    "    `{'text', 'index'}` chunks produced by `text_splitter`, and `date_obj`,\n",
    "    a `datetime` built from `date.timestamp` (or None when it is absent).\n",
    "\n",
    "    Args:\n",
    "        organization: Entity id forwarded to `get_articles`.\n",
    "        date_after: Lower date bound forwarded to `get_articles`.\n",
    "\n",
    "    Returns:\n",
    "        A list of entity dicts; empty when the request failed or the\n",
    "        payload carries no `data`.\n",
    "    \"\"\"\n",
    "    data = await get_articles(organization, date_after)\n",
    "    # get_articles returns None on any non-200 response; one failed request\n",
    "    # must not raise here and abort the whole asyncio.gather in the caller.\n",
    "    if not data:\n",
    "        return []\n",
    "    entities = []\n",
    "    for ent in data.get('data') or []:\n",
    "        entity = ent['entity']\n",
    "        # Articles without body text get an empty chunk list instead of a KeyError.\n",
    "        text = entity.get('text')\n",
    "        chunks = text_splitter.split_text(text) if text else []\n",
    "        entity['texts'] = [{'text': el, 'index': i} for i, el in enumerate(chunks)]\n",
    "        # The timestamp is divided by 1000 before conversion (presumably milliseconds; confirm).\n",
    "        timestamp = (entity.get('date') or {}).get('timestamp')\n",
    "        entity['date_obj'] = datetime.fromtimestamp(timestamp / 1000) if timestamp is not None else None\n",
    "        entities.append(entity)\n",
    "    return entities\n",
    "\n",
    "\n",
    "async def update_articles(top_k: int = 1000, date_after: str = '2024-01-01'):\n",
    "    \"\"\"Fetch articles for the top organizations and import them in batches.\n",
    "\n",
    "    Looks up `top_k` organization ids, runs `process_organization` for all of\n",
    "    them concurrently, flattens the results and writes them to the\n",
    "    'articles' database in slices of `step` rows using `import_query`.\n",
    "    \"\"\"\n",
    "    organization_ids = get_org_ids(top_k)\n",
    "    # One coroutine per organization, awaited together\n",
    "    per_org_articles = await asyncio.gather(\n",
    "        *[process_organization(org_id, date_after) for org_id in organization_ids]\n",
    "    )\n",
    "    # Collapse the per-organization lists into a single list\n",
    "    articles = [article for org_articles in per_org_articles for article in org_articles]\n",
    "    # Write `step` articles per query\n",
    "    for offset in range(0, len(articles), step):\n",
    "        driver.execute_query(\n",
    "            import_query,\n",
    "            data=articles[offset: offset + step],\n",
    "            database_='articles',\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8fc14e8f-8d7d-4d06-93c7-e3ca25953ae2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the import for the 10 highest-importance organizations, using the default date_after.\n",
    "await update_articles(top_k=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa9a6ac2-54a8-4665-87a3-aec597956c7c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
